% Analyze the data from the spiking network for some task consisting of
% trials of equal duration.
%
%
%  function [rs sig_cells] = analyzeVMAstats( area_name, type_name, phase_start_msec, phase_end_msec , start_time_msecs, end_time_msecs,significance_level, save_seconds)
% 
%  phase_start_msec, phase_end_msec:   the start and
%		end times (inclusive) within the trial to analyze the data.  Start counting at 0 for the first msec.
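%  start_time_msecs, end_time_msecs: only trials whose start time t (msec,
%   taken from patternhist.dat) satisfies start_time_msecs <= t < end_time_msecs
%   are analyzed.
%  significance_level = the significance level handed to ranksum; defaults
%   to 0.05 when omitted.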
%
%  save_seconds = the number of seconds saved in 1 spike*.dat file; must
%   be an integer.
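%  RETURNS:
%	rs(neuron, pattern i, pattern j): 1 when ranksum reports that the neuron's spike counts during the
%	specified task phase differ significantly between patterns i and j AND the neuron's mean count for
%	pattern i is more than twice its mean count for pattern j; 0 otherwise.
%	sig_cells(i).list: ids (counting from 0) of the neurons for which rs is 1 for pattern i against
%	every other pattern.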
%
%  EXAMPLE:
%
% analyzeVMAstats('S1','p4',0,249,10000,29999) 
%
%  will analyze patterns
%  which are delivered between 10000 msec and 29999 msec into the
%  simulation. During each trial, the spikes will be analyzed only between
%  relative time 0 and 249. 
%  The file patternhist.dat looks like the listing below. Each line is one
%  trial: the first entry is the start time of the trial in msec and the
%  second entry is the pattern number in the range [0,N-1], where N is the
%  number of patterns used in the experiments (no gaps).
% 
% 
% 0 0
% 250 1
% 500 2
% 750 3
% 1000 0
% 1250 1
% 1500 2
% 1750 3
% 2000 0
% 2250 1
% 2500 2
% 2750 3
% ...
% 
%

function [rs sig_cells] = analyzeVMAstats( area_name, type_name, phase_start_msec, phase_end_msec , start_time_msecs, end_time_msecs,significance_level, save_seconds)
% Collect per-trial spike counts of one neuron group during a task phase,
% test which neurons respond preferentially to each pattern, and plot the
% results.
%
%  area_name, type_name : group to analyze; matched against the `areas` and
%                         `celltype` entries that read_groups provides.
%  phase_start_msec,
%  phase_end_msec       : first and last msec (0-based, inclusive) of the
%                         analysis window, relative to each trial start.
%  start_time_msecs,
%  end_time_msecs       : trials with start_time_msecs <= start < end_time_msecs
%                         are analyzed.
%  significance_level   : alpha handed to ranksum (default 0.05).
%  save_seconds         : seconds of data per spikes*.dat file (default 1).
%
%  rs(n,i,j)            : true when ranksum reports a significant difference
%                         between neuron n's counts for patterns i and j AND
%                         half its mean count for i exceeds its mean for j.
%  sig_cells(i).list    : 0-based ids of the neurons with rs true for
%                         pattern i against every other pattern.
%
%  Side effects: overwrites the globals S (spike matrix of the last file
%  loaded) and mfr (mfr(p).spikes = neurons-by-trials counts for pattern p),
%  and opens figures numbered from 100.

global S
global mfr

% Assign the outputs up front so that every exit path returns something.
rs = [];
sig_cells = struct('list', {});
mfr = [];

if nargin < 7 || isempty(significance_level)
    significance_level = 0.05;
end
if nargin < 8 || isempty(save_seconds)
    save_seconds = 1;
end

patternhist = load('patternhist.dat');
if isempty(patternhist)
    disp('patternhist.dat contains no trials.');
    return;
end

% Convert the 0-based msec offsets into 1-based column offsets.
phase_start_msec = phase_start_msec + 1;
phase_end_msec = phase_end_msec + 1;
% The window is inclusive at both ends.
window_msec = phase_end_msec - phase_start_msec + 1;

read_groups; % Assumes groups.dat has the neural area data
% NOTE(review): read_groups is a script defined elsewhere; this function
% relies on it creating areas, celltype, neuron_id, la and lb.

index = [];
for g = 1:length(areas)
    if strcmp( deblank(areas{g}), area_name) && strcmp( type_name, deblank(celltype{g}))
        index = g;
        break;
    end
end
if isempty(index)
    disp([area_name ' ' type_name ' is not in groups.dat.']);
    return;
end

% neuron_id holds 0-based ids; shift to 1-based row indices.
first = neuron_id(index,1) + 1;
last  = neuron_id(index,2) + 1;
num_neurons = max(neuron_id(:,2)) + 1;

rows = la(index);
cols = lb(index);
N = rows*cols;
if last - first + 1 ~= N
    error('analyzeVMAstats:groupSize', ...
        'Group %s %s spans %d neurons but its layout is %d x %d.', ...
        area_name, type_name, last - first + 1, rows, cols);
end

num_patterns = max( patternhist(:,2)) + 1;
count = zeros( 1, num_patterns );

% One neurons-by-trials count matrix per pattern.  Preallocating keeps
% patterns without any usable trial addressable later on.
mfr = struct('spikes', repmat({zeros(N,0)}, 1, num_patterns));

% Spike files are named after the msec at which they end; map a 1-based
% time column to the file that holds it.  sprintf('%d') avoids the exponent
% notation that num2str(x,6) produces for large values.
file_msec = save_seconds*1000;
spike_file = @(col) sprintf('spikes%d.dat', ceil(col/file_msec)*file_msec);
filename = '';

trial_indices_to_analyze = find(patternhist(:,1)>=start_time_msecs & patternhist(:,1)<end_time_msecs);

for trial = trial_indices_to_analyze'

    t = patternhist( trial, 1 );
    first_col = t + phase_start_msec;
    last_col  = t + phase_end_msec;

    % Load a new spike file if necessary.
    current_filename = spike_file(first_col);
    if ~strcmp(current_filename, filename)
        filename = current_filename;
        disp(['Loading ' filename]);
        sdat = load(filename);
        file_end_msec = ceil(first_col/file_msec)*file_msec;
        if isempty(sdat)
            S = sparse(num_neurons, file_end_msec);
        else
            % Column 1 of sdat is the time, column 2 the neuron id (both
            % 0-based).  Size the matrix so every index fits.
            S = sparse( sdat(:,2)+1, sdat(:,1)+1, 1, ...
                max(num_neurons, max(sdat(:,2))+1), ...
                max(file_end_msec, max(sdat(:,1))+1) );
        end
    end

    % If this pattern presentation is split across files, ignore it.
    if ~strcmp(spike_file(last_col), filename)
        continue;
    end

    current_pattern = patternhist( trial, 2 ) + 1;
    count(current_pattern) = count(current_pattern) + 1;

    % Gather spike counts for stats.
    mfr(current_pattern).spikes(:, count(current_pattern)) = ...
        full(sum( S( first:last, first_col:last_col ), 2));
end

disp('Number of analyzed trials per pattern');
disp(count);

% Per-neuron mean spike count and mean rate (spikes/sec) for each pattern.
% Patterns without trials yield NaN.
mean_counts = zeros(N, num_patterns);
for p = 1:num_patterns
    mean_counts(:,p) = mean(mfr(p).spikes, 2);
end
mean_rates = mean_counts*1000/window_msec;

k=0;
figure(100+k); k= k + 1;

disp('Mean firing rate of cells to each pattern');
side = ceil(sqrt(num_patterns));
for p = 1:num_patterns
    subplot(side,side,p);
    imagesc(rot90(reshape( mean_rates(:,p), rows, cols )))
    colormap(gray);
    colorbar;
    title(['Mean rate (sp/sec) of cells in area ' area_name type_name ' to pattern ' num2str(p) ]);
end

% Show significantly preferential responses to each pattern.
rs = false(N, num_patterns, num_patterns);
num_significant = zeros(num_patterns);

for i = 1:num_patterns
    for j = 1:num_patterns
        % A pattern never discriminates against itself, and the test needs
        % more than one trial on each side.
        if i ~= j && count(i) > 1 && count(j) > 1
            for n = 1:N
                [p_unused, significant] = ranksum( mfr( i ).spikes(n, : ) , mfr( j ).spikes(n, : ) ,significance_level);
                rs(n,i,j) = significant && mean_counts(n,i) * 0.5 > mean_counts(n,j);
            end
            num_significant(i,j) = sum(rs(:,i,j));
        end
    end
    disp('.');
end

disp('number of cells which distinguish pattern "row" from pattern "col"');
disp(num_significant);

discriminators = zeros(1,num_patterns);
merged_discriminator_map = zeros(N,1);

for i = 1:num_patterns
    % Cells that prefer pattern i over every other pattern.
    selective = true(N,1);
    for j = 1:num_patterns
        if i ~= j
            selective = selective & rs(:,i,j);
        end
    end

    merged_discriminator_map = merged_discriminator_map + i*selective;
    discriminators(i) = sum(selective);

    figure(100+k); k= k + 1;
    imagesc(rot90(reshape( double(selective), rows, cols )))
    colormap(gray);
    title(['Cells in area ' area_name type_name ' that distinguish pattern ' num2str(i) ' from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

    % find() is 1-based and first was shifted by one, hence the -2.
    sig_cells(i).list = (find(selective)+first-2)';
    disp('Significant cell numbers (start counting at 0 as in C): ');
    disp(sig_cells(i).list);
end

figure(100+k);k=k+1;
bar(discriminators);
title(['Number of cells in area ' area_name type_name ' that distinguish each pattern from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

figure(100+k);k=k+1;
imagesc(rot90(reshape( merged_discriminator_map, rows, cols )))
title(['Cells in area ' area_name type_name ' that distinguish each pattern from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

% List the per-trial counts of every selective unit, one line per pattern
% (the first number on each line is the pattern index).
for unit = find(merged_discriminator_map')
    disp(['Responses for unit ' num2str(unit)]);
    for j = 1:num_patterns
        disp([j mfr(j).spikes(unit,:)]);
    end
end

% Display how each cell responds to each pattern.
figure(100+k); k= k + 1;
disp('Mean firing rate of each cell grouped by patterns');
unit = 1;
for r = 1:rows
    for c = 1:cols
        subplot(rows,cols,unit);
        bar(mean_rates(unit,:));
        axis off;
        ylim([0 20]);
        if num_patterns > 1
            xlim([1 num_patterns]);
        end
        unit = unit + 1;
    end
end

